{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据集预处理\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import json\n",
    "\n",
    "data_file = 'train.jsonl'\n",
    "jsonl_file = 'news_train.jsonl'\n",
    "\n",
    "# 生成JSONL文件\n",
    "messages = []\n",
    "\n",
    "\n",
    "# 读取jsonl文件\n",
    "with open('train.jsonl', 'r') as file:\n",
    "    for line in file:\n",
    "        # 解析每一行的json数据\n",
    "        data = json.loads(line)\n",
    "        context = data[\"text\"]\n",
    "        catagory = data[\"category\"]\n",
    "        label = data[\"output\"]\n",
    "        message={ \"instruction\":\"你是一个文本分类领域的专家，你会接收到一段文本和几个潜在的分类选项，请输出文本内容的正确类型\",\"input\": f'文本:{context},类型选型:{catagory}',\"output\":label}\n",
    "        messages.append(message)\n",
    "\n",
    "# 保存为JSONL文件\n",
    "with open(jsonl_file, 'w', encoding='utf-8') as file:\n",
    "    for message in messages:\n",
    "        file.write(json.dumps(message, ensure_ascii=False) + '\\n')"
   ]
  },
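  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the sketch below re-reads the first converted record; it assumes the cell above has just written `news_train.jsonl`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first converted record (sketch; assumes news_train.jsonl exists)\n",
    "with open(jsonl_file, 'r', encoding='utf-8') as f:\n",
    "    first = json.loads(f.readline())\n",
    "print(first['instruction'])\n",
    "print(first['input'][:80], '...')  # truncate the long article text\n",
    "print(first['output'])"
   ]
  },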
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset\n",
    "import pandas as pd\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig\n",
    "\n",
    "# 将JSON文件转换为CSV文件\n",
    "df = pd.read_json('./news_train.jsonl',lines = True)\n",
    "ds = Dataset.from_pandas(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>instruction</th>\n",
       "      <th>input</th>\n",
       "      <th>output</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>你是一个文本分类领域的专家，你会接收到一段文本和几个潜在的分类选项，请输出文本内容的正确类型</td>\n",
       "      <td>文本:【 文献号 】4-2054\\t【原文出处】史学月刊\\t【原刊地名】开封\\t【原刊期号】...</td>\n",
       "      <td>History</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>你是一个文本分类领域的专家，你会接收到一段文本和几个潜在的分类选项，请输出文本内容的正确类型</td>\n",
       "      <td>文本:【 文献号 】3-2104\\t【原文出处】《幼儿教育》\\t【原刊地名】杭州\\t【原刊期...</td>\n",
       "      <td>Sports</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>你是一个文本分类领域的专家，你会接收到一段文本和几个潜在的分类选项，请输出文本内容的正确类型</td>\n",
       "      <td>文本:中国环境科学CHINA ENVIRONMENTAL SCIENCE1998年 第18卷...</td>\n",
       "      <td>Enviornment</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                      instruction  \\\n",
       "0  你是一个文本分类领域的专家，你会接收到一段文本和几个潜在的分类选项，请输出文本内容的正确类型   \n",
       "1  你是一个文本分类领域的专家，你会接收到一段文本和几个潜在的分类选项，请输出文本内容的正确类型   \n",
       "2  你是一个文本分类领域的专家，你会接收到一段文本和几个潜在的分类选项，请输出文本内容的正确类型   \n",
       "\n",
       "                                               input       output  \n",
       "0  文本:【 文献号 】4-2054\\t【原文出处】史学月刊\\t【原刊地名】开封\\t【原刊期号】...      History  \n",
       "1  文本:【 文献号 】3-2104\\t【原文出处】《幼儿教育》\\t【原刊地名】杭州\\t【原刊期...       Sports  \n",
       "2  文本:中国环境科学CHINA ENVIRONMENTAL SCIENCE1998年 第18卷...  Enviornment  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型下载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: 100%|█████████▉| 3.71G/3.71G [00:18<00:00, 216MB/s]\n",
      "Downloading: 100%|█████████▉| 3.69G/3.69G [00:12<00:00, 329MB/s]\n",
      "Downloading: 100%|█████████▉| 3.69G/3.69G [00:18<00:00, 216MB/s] \n",
      "Downloading: 100%|█████████▉| 3.30G/3.30G [00:14<00:00, 239MB/s]\n",
      "Downloading: 100%|██████████| 31.0k/31.0k [00:00<00:00, 2.44MB/s]\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from modelscope import snapshot_download, AutoModel, AutoTokenizer\n",
    "import os\n",
    "# model_dir = snapshot_download('qwen/Qwen1.5-7B-Chat', cache_dir='./', revision='master')\n",
    "model_dir = snapshot_download('qwen/Qwen1.5-7B-Chat', revision='master')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: 100%|██████████| 663/663 [00:00<00:00, 131kB/s]\n",
      "Downloading: 100%|██████████| 51.0/51.0 [00:00<00:00, 36.0kB/s]\n",
      "Downloading: 100%|██████████| 243/243 [00:00<00:00, 195kB/s]\n",
      "Downloading: 100%|██████████| 6.73k/6.73k [00:00<00:00, 4.55MB/s]\n",
      "Downloading: 100%|██████████| 1.59M/1.59M [00:00<00:00, 30.6MB/s]\n",
      "Downloading: 100%|██████████| 4.15k/4.15k [00:00<00:00, 3.42MB/s]\n",
      "Downloading: 100%|██████████| 6.70M/6.70M [00:00<00:00, 90.2MB/s]\n",
      "Downloading: 100%|██████████| 1.26k/1.26k [00:00<00:00, 1.00MB/s]\n",
      "Downloading: 100%|██████████| 2.65M/2.65M [00:00<00:00, 47.8MB/s]\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Qwen2Tokenizer(name_or_path='/home/jeeves/.cache/modelscope/hub/./qwen', vocab_size=151643, model_max_length=32768, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<|im_end|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={\n",
       "\t151643: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151644: AddedToken(\"<|im_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151645: AddedToken(\"<|im_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('./qwen/Qwen1.5-7B-Chat', use_fast=False, trust_remote_code=True)\n",
    "tokenizer"
   ]
  },
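  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The preprocessing below hand-writes Qwen's ChatML template, so it is worth confirming that the special tokens it relies on each map to a single id. A minimal sketch:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each ChatML marker should resolve to exactly one token id (151644/151645/151643 per the repr above)\n",
    "for tok in ['<|im_start|>', '<|im_end|>', '<|endoftext|>']:\n",
    "    print(tok, tokenizer.convert_tokens_to_ids(tok))"
   ]
  },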
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_func(example):\n",
    "    MAX_LENGTH = 384    # Llama分词器会将一个中文字切分为多个token，因此需要放开一些最大长度，保证数据的完整性\n",
    "    input_ids, attention_mask, labels = [], [], []\n",
    "    instruction = tokenizer(f\"<|im_start|>system\\n你是一个文本分类领域的专家，你会接收到一段文本和几个潜在的分类选项，请输出文本内容的正确类型<|im_end|>\\n<|im_start|>user\\n{example['input']}<|im_end|>\\n<|im_start|>assistant\\n\", add_special_tokens=False)  # add_special_tokens 不在开头加 special_tokens\n",
    "    response = tokenizer(f\"{example['output']}\", add_special_tokens=False)\n",
    "    input_ids = instruction[\"input_ids\"] + response[\"input_ids\"] + [tokenizer.pad_token_id]\n",
    "    attention_mask = instruction[\"attention_mask\"] + response[\"attention_mask\"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1\n",
    "    labels = [-100] * len(instruction[\"input_ids\"]) + response[\"input_ids\"] + [tokenizer.pad_token_id]  \n",
    "    if len(input_ids) > MAX_LENGTH:  # 做一个截断\n",
    "        input_ids = input_ids[:MAX_LENGTH]\n",
    "        attention_mask = attention_mask[:MAX_LENGTH]\n",
    "        labels = labels[:MAX_LENGTH]\n",
    "    return {\n",
    "        \"input_ids\": input_ids,\n",
    "        \"attention_mask\": attention_mask,\n",
    "        \"labels\": labels\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 4000/4000 [06:59<00:00,  9.52 examples/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input_ids', 'attention_mask', 'labels'],\n",
       "    num_rows: 4000\n",
       "})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_id = ds.map(process_func, remove_columns=ds.column_names)\n",
    "tokenized_id"
   ]
  },
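  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick way to verify the label masking is to decode one processed example: the full sequence should show the ChatML prompt, while the supervised positions (labels != -100) should contain only the category label plus the end token. A sketch:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example = tokenized_id[0]\n",
    "# Full input: system prompt, user text, assistant header, then the answer\n",
    "print(tokenizer.decode(example['input_ids']))\n",
    "# Supervised part only: everything the loss is computed on\n",
    "print(tokenizer.decode([t for t in example['labels'] if t != -100]))"
   ]
  },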
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.09s/it]\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained('qwen/Qwen1.5-7B-Chat', device_map=\"auto\",torch_dtype=torch.bfloat16)\n",
    "\n",
    "\n",
    "model.enable_input_require_grads() # 开启梯度检查点时，要执行该方法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 加载LoRA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=8, target_modules={'q_proj', 'k_proj', 'v_proj', 'up_proj', 'down_proj', 'gate_proj', 'o_proj'}, lora_alpha=32, lora_dropout=0.1, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from peft import LoraConfig, TaskType, get_peft_model\n",
    "\n",
    "config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM, \n",
    "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
    "    inference_mode=False, # 训练模式\n",
    "    r=8, # Lora 秩\n",
    "    lora_alpha=32, # Lora alaph，具体作用参见 Lora 原理\n",
    "    lora_dropout=0.1# Dropout 比例\n",
    ")\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path='/home/jeeves/.cache/modelscope/hub/qwen/Qwen1___5-7B-Chat', revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=8, target_modules={'q_proj', 'k_proj', 'v_proj', 'up_proj', 'down_proj', 'gate_proj', 'o_proj'}, lora_alpha=32, lora_dropout=0.1, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = get_peft_model(model, config)\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 19,988,480 || all params: 7,741,313,024 || trainable%: 0.2582\n"
     ]
    }
   ],
   "source": [
    "model.print_trainable_parameters()"
   ]
  },
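  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed count can be checked by hand: each LoRA pair on a linear layer of shape d_in×d_out adds r·(d_in+d_out) parameters. Assuming Qwen1.5-7B's dimensions (hidden_size=4096, intermediate_size=11008, 32 layers), the sketch below reproduces the 19,988,480 figure.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Back-of-the-envelope LoRA parameter count (assumed Qwen1.5-7B dims)\n",
    "r, hidden, inter, layers = 8, 4096, 11008, 32\n",
    "attn = 4 * r * (hidden + hidden)                       # q_proj, k_proj, v_proj, o_proj\n",
    "mlp = 2 * r * (hidden + inter) + r * (inter + hidden)  # gate_proj, up_proj, down_proj\n",
    "print(layers * (attn + mlp))  # 19988480, matching print_trainable_parameters()"
   ]
  },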
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 配置训练参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    output_dir=\"./output/Qwen1.5\",\n",
    "    per_device_train_batch_size=4,\n",
    "    gradient_accumulation_steps=4,\n",
    "    logging_steps=10,\n",
    "    num_train_epochs=3,\n",
    "    save_steps=100,\n",
    "    learning_rate=1e-4,\n",
    "    save_on_each_node=True,\n",
    "    gradient_checkpointing=True\n",
    ")"
   ]
  },
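  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With these settings the effective batch size is 4 × 4 = 16 (assuming a single training process), so the 4,000 training examples give 4000 / 16 = 250 optimizer steps per epoch, or 750 steps over 3 epochs, matching the progress bar in the training run below.\n"
   ]
  },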
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 加载SwanLab(pip install swanlab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    }
   ],
   "source": [
    "# from swanlab.integration.huggingface import SwanLabCallback\n",
    "\n",
    "# swanlab_callback = SwanLabCallback(project=\"hf-visualization\")\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=tokenized_id,\n",
    "    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),\n",
    "    # callbacks=[swanlab_callback],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='750' max='750' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [750/750 25:04, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>87.662800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.013800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>17.218700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>2.142600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>0.303900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>10.023100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>3.488000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>1.142900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>1.989100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>2.739200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>110</td>\n",
       "      <td>5.219000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>0.611300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>130</td>\n",
       "      <td>0.031000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>140</td>\n",
       "      <td>1.934900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>150</td>\n",
       "      <td>0.618300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>160</td>\n",
       "      <td>0.035500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>170</td>\n",
       "      <td>0.672600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>180</td>\n",
       "      <td>0.423500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>190</td>\n",
       "      <td>0.004700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>3.945700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>210</td>\n",
       "      <td>0.728000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>220</td>\n",
       "      <td>0.002700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>230</td>\n",
       "      <td>0.014700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>240</td>\n",
       "      <td>3.243300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>250</td>\n",
       "      <td>2.858700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>260</td>\n",
       "      <td>4.517300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>270</td>\n",
       "      <td>0.097300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>280</td>\n",
       "      <td>0.094300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>290</td>\n",
       "      <td>0.038400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.026000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>310</td>\n",
       "      <td>0.036600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>320</td>\n",
       "      <td>3.330700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>330</td>\n",
       "      <td>0.126200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>340</td>\n",
       "      <td>0.001700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>350</td>\n",
       "      <td>0.030200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>360</td>\n",
       "      <td>0.130100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>370</td>\n",
       "      <td>0.031800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>380</td>\n",
       "      <td>0.072000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>390</td>\n",
       "      <td>0.010100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.731800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>410</td>\n",
       "      <td>0.053600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>420</td>\n",
       "      <td>0.007100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>430</td>\n",
       "      <td>0.029400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>440</td>\n",
       "      <td>0.016700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>450</td>\n",
       "      <td>0.004400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>460</td>\n",
       "      <td>0.044200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>470</td>\n",
       "      <td>0.025300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>480</td>\n",
       "      <td>0.001200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>490</td>\n",
       "      <td>0.029800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.001900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>510</td>\n",
       "      <td>0.020900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>520</td>\n",
       "      <td>0.025100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>530</td>\n",
       "      <td>0.001500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>540</td>\n",
       "      <td>0.000600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>550</td>\n",
       "      <td>0.119200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>560</td>\n",
       "      <td>0.003800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>570</td>\n",
       "      <td>0.003700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>580</td>\n",
       "      <td>0.000800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>590</td>\n",
       "      <td>0.004900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.002300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>610</td>\n",
       "      <td>0.031600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>620</td>\n",
       "      <td>0.052600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>630</td>\n",
       "      <td>0.005000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>640</td>\n",
       "      <td>0.002200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>650</td>\n",
       "      <td>0.004300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>660</td>\n",
       "      <td>0.001100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>670</td>\n",
       "      <td>0.005600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>680</td>\n",
       "      <td>0.003500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>690</td>\n",
       "      <td>0.001400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.009800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>710</td>\n",
       "      <td>0.099900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>720</td>\n",
       "      <td>0.025300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>730</td>\n",
       "      <td>0.001900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>740</td>\n",
       "      <td>0.001600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>750</td>\n",
       "      <td>0.004700</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/jeeves/.cache/modelscope/hub/qwen/Qwen1___5-7B-Chat - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/jeeves/.cache/modelscope/hub/qwen/Qwen1___5-7B-Chat - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/jeeves/.cache/modelscope/hub/qwen/Qwen1___5-7B-Chat - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/jeeves/.cache/modelscope/hub/qwen/Qwen1___5-7B-Chat - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/jeeves/.cache/modelscope/hub/qwen/Qwen1___5-7B-Chat - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/jeeves/.cache/modelscope/hub/qwen/Qwen1___5-7B-Chat - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/jeeves/.cache/modelscope/hub/qwen/Qwen1___5-7B-Chat - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/home/jeeves/.conda/envs/quen_env/lib/python3.9/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=750, training_loss=2.093196896828711, metrics={'train_runtime': 1506.9786, 'train_samples_per_second': 7.963, 'train_steps_per_second': 0.498, 'total_flos': 1.96825646628864e+17, 'train_loss': 2.093196896828711, 'epoch': 3.0})"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
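  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With save_steps=100, the final step 750 may not be checkpointed automatically, so it is worth saving the finished LoRA adapter explicitly. A minimal sketch; the adapter path here is an arbitrary choice:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the trained LoRA adapter and tokenizer (hypothetical output path)\n",
    "model.save_pretrained('./output/Qwen1.5/final_adapter')\n",
    "tokenizer.save_pretrained('./output/Qwen1.5/final_adapter')"
   ]
  }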
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "quen_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
